/* *****************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package facewap;


import face3wap.Img3GanGraph;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.graph.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @author WangFeng
 */
public class GanModel {

    private static final Logger log = LoggerFactory.getLogger(GanModel.class);
    /** RNG seed for reproducible weight initialization (public: referenced by callers). */
    public static final int seed = 42;
    private static final double LEARNING_RATE = 0.0002;
    private static final double GRADIENT_THRESHOLD = 100.0;
    // Adam with beta1 = 0.5 — the customary DCGAN updater settings.
    private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();

    /**
     * Builds and initializes the full GAN computation graph: an encoder/decoder
     * ("g1".."g19") that filters and then reconstructs the image, followed by a
     * discriminator head ("d1".."d3", "out") fed by stacking the second input with
     * the generator output.
     *
     * <p>Inputs: "trainFeatures" is {@code inputSize x inputSize x ncIn} in NHWC
     * layout; "input2" is {@code inputSize x inputSize x (ncIn+1)} in NCHW layout.
     *
     * <p>NOTE(review): every {@link ReshapeVertex} hard-codes {@code batch} as the
     * first dimension, so the resulting graph only accepts minibatches of exactly
     * that size — confirm callers always feed this batch size.
     *
     * @param ncIn      number of input channels; {@code 0} defaults to 3 (RGB)
     * @param inputSize input height/width in pixels; {@code 0} defaults to 64
     * @param batch     minibatch size baked into the reshape vertices
     * @return an initialized {@link ComputationGraph} ready for training
     */
    public static ComputationGraph createEncoder(int ncIn, int inputSize, int batch) {
        // Apply defaults via locals instead of reassigning the parameters.
        final int channels = ncIn == 0 ? 3 : ncIn;
        final int size = inputSize == 0 ? 64 : inputSize;

        ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .updater(UPDATER)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .gradientNormalizationThreshold(GRADIENT_THRESHOLD)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.IDENTITY)
                .graphBuilder()
                .addInputs("trainFeatures", "input2")
                .setInputTypes(InputType.convolutional(size, size, channels, CNN2DFormat.NHWC),
                        InputType.convolutional(size, size, channels + 1, CNN2DFormat.NCHW))
                // Image filtering: downsampling convolution stack (encoder).
                .addLayer("g1", new ConvolutionLayer.Builder().kernelSize(5,5).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).build(), "trainFeatures")
                .addLayer("g2", new ConvolutionLayer.Builder().kernelSize(3,3).convolutionMode(ConvolutionMode.Same).stride(2,2).hasBias(false)
                        .nOut(128).activation(Activation.RELU).build(), "g1")
                .addLayer("g3", new ConvolutionLayer.Builder().kernelSize(3,3).convolutionMode(ConvolutionMode.Same).stride(2,2).hasBias(false)
                        .nOut(256).activation(Activation.RELU).build(), "g2")
                .addLayer("g4", new ConvolutionLayer.Builder().kernelSize(3,3).convolutionMode(ConvolutionMode.Same).stride(2,2).hasBias(false)
                        .nOut(512).activation(Activation.RELU).build(), "g3")
                .addLayer("g5", new ConvolutionLayer.Builder().kernelSize(3,3).convolutionMode(ConvolutionMode.Same).stride(2,2).hasBias(false)
                        .nOut(1024).activation(Activation.RELU).build(), "g4")
                // Dense bottleneck, then project back up to a 4x4x1024 feature map.
                .addLayer("g6", new DenseLayer.Builder().nOut(1024)
                        .build(), "g5")
                .addLayer("g7", new DenseLayer.Builder().nOut(4 * 4 * 1024)
                        .build(), "g6")
                .addVertex("reshape", new ReshapeVertex(batch, 4, 4, 1024), "g7")
                .addLayer("g8", new ConvolutionLayer.Builder().kernelSize(3,3).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(512 * 8 * 8).activation(new ActivationLReLU(0.1)).build(), "reshape")
                // Image restoration: upsampling path with residual (element-wise add) blocks.
                .addVertex("reshape1", new ReshapeVertex(batch, 8, 8, 512), "g8")
                .addLayer("g9", new ConvolutionLayer.Builder(new int[]{3, 3}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(256 * 4).activation(new ActivationLReLU(0.1)).build(), "reshape1")
                .addLayer("g10", new ConvolutionLayer.Builder(new int[]{3, 3}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(128 * 4).activation(new ActivationLReLU(0.1)).build(), "g9")
                .addLayer("g11", new ConvolutionLayer.Builder(new int[]{3, 3}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64 * 4).activation(new ActivationLReLU(0.1)).build(), "g10")
                // First residual block: g11 -> g12 -> g13, skip connection from g11.
                .addLayer("g12", new ConvolutionLayer.Builder(new int[]{3, 3}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64 * 4).activation(new ActivationLReLU(0.2)).build(), "g11")
                .addLayer("g13", new ConvolutionLayer.Builder(new int[]{3, 3}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64 * 4).build(), "g12")
                .addVertex("add1", new ElementWiseVertex(ElementWiseVertex.Op.Add), "g13", "g11")
                .addLayer("g14", new DenseLayer.Builder().nOut(64 * 4).activation(new ActivationLReLU(0.2))
                        .build(), "add1")
                .addVertex("reshape2", new ReshapeVertex(batch, 64, 64, 512), "g14")
                // Second residual block: reshape2 -> g15 -> g16, skip from reshape2.
                .addLayer("g15", new ConvolutionLayer.Builder(new int[]{3, 3}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).activation(new ActivationLReLU(0.2)).build(), "reshape2")
                .addLayer("g16", new ConvolutionLayer.Builder(new int[]{3, 3}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).build(), "g15")
                .addVertex("add2", new ElementWiseVertex(ElementWiseVertex.Op.Add), "g16", "reshape2")
                .addLayer("g17", new DenseLayer.Builder().nOut(64).activation(new ActivationLReLU(0.2))
                        .build(), "add2")
                .addVertex("reshape3", new ReshapeVertex(batch, 64, 64, 64), "g17")
                // Output heads: 1-channel sigmoid mask and 3-channel tanh image, merged.
                .addLayer("g18", new ConvolutionLayer.Builder(new int[]{5, 5}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(1).activation(Activation.SIGMOID).build(), "reshape3")
                .addLayer("g19", new ConvolutionLayer.Builder(new int[]{5, 5}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(3).activation(Activation.TANH).build(), "g18")
                .addVertex("merge2", new MergeVertex(), "g18", "g19")
                // Discriminator: stack the second input with the generator output.
                .addVertex("stack", new StackVertex(), "input2", "merge2")
                .addLayer("d1", new ConvolutionLayer.Builder(new int[]{4,4}, new int[]{2,2}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).activation(new ActivationLReLU(0.2)).build(), "stack")
                .addLayer("d2", new ConvolutionLayer.Builder(new int[]{4,4}, new int[]{2,2}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(128).activation(new ActivationLReLU(0.2)).build(), "d1")
                .addLayer("d3", new ConvolutionLayer.Builder(new int[]{4,4}, new int[]{2,2}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(256).activation(new ActivationLReLU(0.2)).build(), "d2")
                .addLayer("out", new ConvolutionLayer.Builder(new int[]{4, 4}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(1).activation(Activation.SIGMOID).build(), "d3")
                .setOutputs("out")
                .build();
        ComputationGraph model = new ComputationGraph(config);
        model.init();
        return model;
    }
}

